import cv2
import numpy as np
import os
import matplotlib.pyplot as plt
import pandas as pd
import glob
import math
      
def rec(u, v):
    """Return the recovery angular error between RGB vectors ``u`` and ``v``, in degrees.

    Only the first three components of each vector are used. The cosine of
    the angle between them is clamped to [-1, 1] before ``acos`` so rounding
    error cannot push it outside the function's domain.

    For plain Python numbers, a zero-length vector raises ZeroDivisionError.
    """
    dot = u[0] * v[0] + u[1] * v[1] + u[2] * v[2]
    norm_u = math.sqrt(u[0] * u[0] + u[1] * u[1] + u[2] * u[2])
    norm_v = math.sqrt(v[0] * v[0] + v[1] * v[1] + v[2] * v[2])
    cosine = dot / (norm_u * norm_v)
    clamped = max(-1, min(cosine, 1))
    # math.pi is the same double as the literal 3.141592653589793.
    return 180 * float(math.acos(clamped)) / math.pi
  
def rep(u, v):
    """Return the reproduction angular error of estimate ``u`` against ``v``, in degrees.

    ``u`` is divided channel-wise by ``v``. The result is the angle between
    that ratio vector and the neutral direction (1, 1, 1), with the cosine
    clamped to [-1, 1]. Only the first three components are used. A zero
    channel in ``v`` raises ZeroDivisionError for plain Python numbers.
    """
    ratio_r = u[0] / v[0]
    ratio_g = u[1] / v[1]
    ratio_b = u[2] / v[2]
    squared = math.pow(ratio_r, 2) + math.pow(ratio_g, 2) + math.pow(ratio_b, 2)
    # |(1, 1, 1)| = sqrt(3), folded into the square root below.
    cosine = (ratio_r + ratio_g + ratio_b) / math.sqrt(squared * 3)
    return 180 * float(math.acos(max(-1, min(cosine, 1)))) / math.pi

def arc(target):
    """Map an RGB vector to 2-D coordinates used by the STD metric.

    The returned point's distance from the origin equals the angle, in
    radians, between ``target`` and the neutral direction (1, 1, 1). Its
    direction is that of the chromaticity offset
    (2r - g - b, sqrt(3) * (g - b)).

    Returns:
        (x, y) tuple of floats. Neutral inputs (r == g == b, including the
        all-zero vector) have no defined direction. They map to (0.0, 0.0),
        the limit as the angle goes to 0, instead of dividing by zero.
    """
    rt = target[0]
    gt = target[1]
    bt = target[2]
    dx = 2 * rt - gt - bt
    dy = gt - bt
    chroma = math.sqrt(math.pow(dx, 2) + 3 * math.pow(dy, 2))
    if chroma == 0:
        return 0.0, 0.0
    cosine = (rt + gt + bt) / math.sqrt(3 * (rt * rt + gt * gt + bt * bt))
    # Clamp like rec()/rep(): for near-neutral colours rounding can push the
    # cosine just above 1, and math.acos would raise a domain error.
    angle = math.acos(max(-1, min(cosine, 1)))
    scale = angle / chroma
    return scale * dx, scale * math.sqrt(3) * dy

def MIC(seq, ns):
    """Return the maximum angular change between consecutive frames, in degrees.

    Args:
        seq: indexable sequence of RGB estimates, one per frame.
        ns: number of leading frames of ``seq`` to consider.

    Each pair (seq[k], seq[k + 1]) with k < ns - 1 is compared with rec().
    Raises ValueError (max of an empty list) when ns < 2.
    """
    changes = [rec(seq[k], seq[k + 1]) for k in range(ns - 1)]
    return max(changes)
      
def STD(seq, ns):
    """Return the spread of the first ``ns`` estimates in ``seq``, in degrees.

    Each estimate is mapped to a 2-D point with arc(). The result is the
    square root of the summed population variances of the x and y
    coordinates, i.e. the RMS distance from the centroid, converted from
    radians to degrees. Raises ZeroDivisionError when ns == 0.
    """
    points = [arc(seq[k]) for k in range(ns)]
    xs = [p[0] for p in points]
    ys = [p[1] for p in points]
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    var_x = 0
    var_y = 0
    # Accumulate term by term rather than with sum(), which uses a
    # compensated float summation on newer Pythons and would round differently.
    for px, py in zip(xs, ys):
        var_x += (px - mean_x) * (px - mean_x) / ns
        var_y += (py - mean_y) * (py - mean_y) / ns
    return 180 * math.sqrt(var_x + var_y) / math.pi

# Device names as they appear in eval.csv file paths, and the matching
# suffixes of the per-device test-split files (index-aligned lists).
dataset_device = ['HuaweiMate30', 'HuaweiP30PRO', 'iphone14pm', 'vivoiqooneo5', 'Xiaomi11PRO', 'Xiaomi13']
num_device = ['mate30', 'P30pro', 'iphonepm', 'vivo', 'xiaomi11pro', 'xiaomi13']
# Per-device summary: [mean AE, mean MIC, mean STD]. Not named `all`, so the
# builtin all() stays usable.
summary = {name: [] for name in num_device}

# Evaluation run whose eval.csv is scored; per-device CSVs are written here too.
path_to_log = "./TAWB/test/cta/logs/tccnet_cta_1700736880.647878"
# Indices into dataset_device/num_device to evaluate. Only index 3 (vivo) is
# selected; use range(len(num_device)) to evaluate every device.
device_indices = range(3, 4)


def _parse_vector(text):
    """Parse a stringified array from eval.csv into a list of floats.

    Drops the first two and last two characters, presumably the '[[' / ']]'
    of a stringified 2-D array (confirm against eval.csv). The remainder is
    then split on whitespace.
    """
    return [float(tok) for tok in text[2:-2].split()]


# eval.csv does not depend on the device, so read and unpack it once.
df = pd.read_csv(path_to_log + "/eval.csv")
files = df['file_names'].values.tolist()
predss = df['preds'].values.tolist()
gtss = df['ground_truths'].values.tolist()

for i in device_indices:
    device = dataset_device[i]
    eval_data = {"file_names": [], "mic": [], "std": [], "ae": []}

    test_path = './TAWB/dataset/CTA-Set/test_'+num_device[i]+'.npy'
    # NOTE(review): allow_pickle=True unpickles arbitrary objects; only load trusted files.
    test_info = np.load(test_path, allow_pickle=True).item()
    seqs = test_info['id']

    # Group this device's rows by sequence id in one pass. The id is the last
    # path component of the first comma-separated field. File order is kept,
    # and every row is no longer rescanned for each sequence.
    rows_by_seq = {}
    for j, name in enumerate(files):
        if device in name:
            seq_num = str(name.split(',')[0]).split('/')[-1]
            rows_by_seq.setdefault(seq_num, []).append(j)

    for seq in seqs:
        rows = rows_by_seq.get(str(seq), [])
        # MIC needs at least two frames; fewer would crash on max() of an empty list.
        if len(rows) < 2:
            print(f"warning: skipping sequence {seq}: {len(rows)} frame(s) found for {device}")
            continue
        aes = []
        seq_all = []
        for j in rows:
            print(files[j])
            pred = _parse_vector(predss[j])
            gt = _parse_vector(gtss[j])
            print(gt)
            aes.append(rec(pred, gt))
            seq_all.append(pred)
        ns = len(rows)
        mic = MIC(seq_all, ns)
        std = STD(seq_all, ns)
        sae = sum(aes)/ns

        print(mic, std, sae)
        eval_data["file_names"].append(seq)
        eval_data["mic"].append(mic)
        eval_data["std"].append(std)
        eval_data["ae"].append(sae)

    # Average over the sequences actually evaluated. This equals len(seqs)
    # whenever none were skipped.
    n_done = len(eval_data["file_names"])
    if n_done == 0:
        print(f"warning: no sequences evaluated for {device}")
        continue
    mean_mic = sum(eval_data["mic"])/n_done
    mean_std = sum(eval_data["std"])/n_done
    mean_ae = sum(eval_data["ae"])/n_done
    print(mean_mic)
    print(mean_std)
    print(mean_ae)

    summary[num_device[i]].extend([mean_ae, mean_mic, mean_std])

    pd.DataFrame(eval_data).to_csv(os.path.join(path_to_log, num_device[i]+"_mic.csv"), index=False)

print(summary)
